import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
from CreateCNN import CreateCNN
from CreateResNet import CreateResNet
from DataSet import Dataset
import copy

# Passed to Dataset() when building the train/test splits
# (presumably the number of training samples — confirm against DataSet.py).
train_size = 50000

# Which saved checkpoint to evaluate; selects the epoch{N}.pt file below.
epoch = 1200
ModelPath = f'Saved01/SGD11/epoch{epoch}.pt'

def ClassAccuracies(ModelArchitecture, ModelPath, num_classes=10, test_bs=100, num_workers=8):
    """Load a saved model and print its accuracy for each class on the test split.

    Args:
        ModelArchitecture: ``'CNN'`` (built by ``CreateCNN``) or ``'ResNet'``
            (built by ``CreateResNet``).
        ModelPath: Path of a ``state_dict`` checkpoint for that architecture.
        num_classes: Number of classes; labels must be integers in
            ``[0, num_classes)``. Defaults to 10 (e.g. CIFAR-10).
        test_bs: Batch size used for the test DataLoader.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A list of ``num_classes`` accuracies as fractions in ``[0, 1]``; a class
        with no test samples gets 0.

    Raises:
        ValueError: If ``ModelArchitecture`` is not ``'CNN'`` or ``'ResNet'``.
    """
    builders = {'CNN': CreateCNN, 'ResNet': CreateResNet}
    if ModelArchitecture not in builders:
        # Without this check `model` would be unbound and fail later with a
        # confusing UnboundLocalError.
        raise ValueError(
            f"Unknown ModelArchitecture {ModelArchitecture!r}; "
            f"expected one of {sorted(builders)}"
        )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = builders[ModelArchitecture]()
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(ModelPath, map_location=device))
    # Inputs are moved to `device` below, so the weights must live there too.
    model.to(device)

    # Only the test split is evaluated; the training split is not needed.
    _, test_data = Dataset(train_size)
    print("Working on GPU" if device.type == "cuda" else "Working on CPU")

    test_loader = DataLoader(
        test_data,
        batch_size=test_bs,
        shuffle=False,  # sample order does not affect per-class counts
        num_workers=num_workers,
        pin_memory=device.type == "cuda",  # pinning only helps host-to-GPU copies
    )

    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))
            # Predicted class = argmax over dim 1. No squeeze(): a batch of size 1
            # must keep shape (1,) so it compares/indexes like any other batch.
            preds = outputs.argmax(dim=1).cpu()
            labels = labels.cpu()
            # Vectorised per-class counting; avoids a per-sample .item() loop that
            # forced a device sync for every sample.
            class_total += torch.bincount(labels, minlength=num_classes)
            class_correct += torch.bincount(labels[preds == labels], minlength=num_classes)

    accuracies = []
    for i in range(num_classes):
        total = class_total[i].item()
        class_accuracy = class_correct[i].item() / total if total else 0
        accuracies.append(class_accuracy)
        print(f'Accuracy for Class {i}: {100 * class_accuracy:.2f}%')
    return accuracies

# Run the evaluation only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    ClassAccuracies('ResNet', ModelPath=ModelPath)
